import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers
from cifar10 import cifar10_input
import numpy as np

# Input pipeline: CIFAR-10 binary batches read through the cifar10_input helper.
# The first call builds the training stream, the second the evaluation stream;
# both yield batches of `batch_size` examples.
batch_size = 128
data_dir = './cifar/cifar-10-batches-bin'

images_train, labels_train = cifar10_input.inputs(
    eval_data=False,
    data_dir=data_dir,
    batch_size=batch_size,
)
images_test, labels_test = cifar10_input.inputs(
    eval_data=True,
    data_dir=data_dir,
    batch_size=batch_size,
)

# Placeholders: x receives 24x24x3 image batches (the shape the input pipeline
# is fed as), y receives one-hot labels for the 10 classes (ids 0-9).
x = tf.placeholder(tf.float32, [None, 24, 24, 3])
y = tf.placeholder(tf.float32, [None, 10])

# Effectively a no-op given x's declared shape; kept so x_image remains defined.
x_image = tf.reshape(x, [-1, 24, 24, 3])

# Built with the tf.contrib layers wrappers.
# Each stage factorizes a 5x5 convolution into a 5x1 then a 1x5 convolution
# (64 filters, ReLU, SAME padding), followed by 2x2 / stride-2 max pooling:
# spatial size 24 -> 12 -> 6.
h_conv10 = layers.conv2d(x_image, 64, [5, 1], 1, 'SAME', activation_fn=tf.nn.relu)
h_conv1 = layers.conv2d(h_conv10, 64, [1, 5], 1, 'SAME', activation_fn=tf.nn.relu)
h_pool1 = layers.max_pool2d(h_conv1, [2, 2], stride=2, padding='SAME')
h_conv20 = layers.conv2d(h_pool1, 64, [5, 1], 1, 'SAME', activation_fn=tf.nn.relu)
h_conv2 = layers.conv2d(h_conv20, 64, [1, 5], 1, 'SAME', activation_fn=tf.nn.relu)
h_pool2 = layers.max_pool2d(h_conv2, [2, 2], stride=2, padding='SAME')

# A 6x6 / stride-6 average pool over the 6x6x64 map leaves 1x1x64
# (global average pooling); flatten to [batch, 64].
nt_hpool2 = layers.avg_pool2d(h_pool2, [6, 6], stride=6, padding='SAME')
nt_hpool_flat = tf.reshape(nt_hpool2, [-1, 64])

# Linear layer producing unnormalized class scores; y_conv holds the softmax
# probabilities derived from them.
logits = layers.fully_connected(nt_hpool_flat, 10, activation_fn=None)
y_conv = tf.nn.softmax(logits)

# Cross-entropy summed over the batch (same reduction as before). It is computed
# from the logits rather than as y * log(softmax): when the softmax saturates,
# an entry becomes exactly 0 and log(0) = -inf turns the loss/gradients into NaN.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Fraction of examples whose highest-scoring class matches the one-hot label.
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# Start training.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# The input tensors are fed by queue runners; a Coordinator lets us stop and
# join those threads cleanly when training/evaluation is finished.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Identity matrix used for one-hot encoding: row k is the encoding of label k.
one_hot = np.eye(10, dtype=float)

try:
    for i in range(2000):
        # image_batch: (batch, 24, 24, 3); label_batch: integer class ids,
        # e.g. [0 5 6 0 1 2 5 ... 2 0 7 3 7].
        image_batch, label_batch = sess.run([images_train, labels_train])
        label_b = one_hot[label_batch]

        train_step.run(feed_dict={x: image_batch, y: label_b}, session=sess)

        if i % 200 == 0:
            # Accuracy on the batch just trained on (not a held-out estimate).
            train_accuracy = accuracy.eval(feed_dict={x: image_batch, y: label_b}, session=sess)
            print('step %d, train_accuracy %g' % (i, train_accuracy))

    # Evaluate. NOTE: this scores a single test batch, not the whole test set.
    image_batch, label_batch = sess.run([images_test, labels_test])
    label_b = one_hot[label_batch]
    test_accuracy = accuracy.eval(feed_dict={x: image_batch, y: label_b}, session=sess)
    print('Finished!!!, test_accuracy %g' % test_accuracy)
finally:
    # Stop the input-queue threads before releasing the session they run in.
    coord.request_stop()
    coord.join(threads)
    sess.close()
